{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyMZH6DyrPy3+/Sduu4e1bhA",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/greengerong/awesome-llm/blob/main/colab/meta_llama_Llama_2_7b_hf_trl%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 安装 trl\n",
        "\n",
        "参考链接：[Llama 2 来袭 - 在 Hugging Face 上玩转它](https://huggingface.co/blog/zh/llama2)"
      ],
      "metadata": {
        "id": "HW1DMp3-AYCk"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N-yRXraS_y-B"
      },
      "outputs": [],
      "source": [
        "# Colab setup: install the libraries, fetch the trl example scripts, log in to the Hub.\n",
        "# `%cd` moves the kernel itself; `!cd` only changes directory inside a throwaway subshell.\n",
        "%cd /content/\n",
        "# `%pip` installs into the environment of the running kernel.\n",
        "# NOTE(review): versions are unpinned, and the fine-tuning cell needs\n",
        "# examples/scripts/sft_trainer.py to exist in the checkout -- consider pinning both.\n",
        "%pip install trl peft bitsandbytes transformers huggingface_hub xformers\n",
        "# Clone only when the checkout is missing so the cell can be re-run.\n",
        "![ -d trl ] || git clone https://github.com/lvwerra/trl\n",
        "%pip install -r ./trl/requirements.txt\n",
        "# Log in to the Hugging Face Hub (prompts for an access token).\n",
        "!huggingface-cli login"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 运行模型"
      ],
      "metadata": {
        "id": "Ne7WzVv-KBAm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import gc\n",
        "\n",
        "import torch\n",
        "import transformers\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "# Hub id of the checkpoint used for this quick generation check.\n",
        "model = \"meta-llama/Llama-2-7b-chat-hf\"\n",
        "\n",
        "# Sampling is enabled below, so fix the seed to make re-runs comparable.\n",
        "transformers.set_seed(42)\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model)\n",
        "pipeline = transformers.pipeline(\n",
        "    \"text-generation\",\n",
        "    model=model,\n",
        "    tokenizer=tokenizer,  # reuse the tokenizer loaded above instead of loading it again\n",
        "    torch_dtype=torch.float16,\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "\n",
        "sequences = pipeline(\n",
        "    'I liked \"Breaking Bad\" and \"Band of Brothers\". Do you have any recommendations of other shows I might like?\\n',\n",
        "    do_sample=True,\n",
        "    top_k=10,\n",
        "    num_return_sequences=1,\n",
        "    eos_token_id=tokenizer.eos_token_id,\n",
        "    max_length=200,\n",
        ")\n",
        "for seq in sequences:\n",
        "    print(f\"Result: {seq['generated_text']}\")\n",
        "\n",
        "# Release the fp16 model before the fine-tuning cell: that cell trains in a separate\n",
        "# `!python` process, which cannot use GPU memory still held by this kernel.\n",
        "# Re-run this cell to load the pipeline again.\n",
        "del pipeline\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n"
      ],
      "metadata": {
        "id": "uX1chtY6KBwN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# trl 微调"
      ],
      "metadata": {
        "id": "Cg4UUlBTAh4M"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Launch the SFT example script from the local trl checkout as a shell command.\n",
        "# Arguments passed: base model meta-llama/Llama-2-7b-hf, dataset\n",
        "# timdettmers/openassistant-guanaco, the 4-bit loading and PEFT flags, batch size 4,\n",
        "# and 2 gradient-accumulation steps.\n",
        "# NOTE(review): the script path is relative to the kernel's working directory, so this\n",
        "# assumes the trl checkout lives there -- confirm before running.\n",
        "!python trl/examples/scripts/sft_trainer.py \\\n",
        "    --model_name meta-llama/Llama-2-7b-hf \\\n",
        "    --dataset_name timdettmers/openassistant-guanaco \\\n",
        "    --load_in_4bit \\\n",
        "    --use_peft \\\n",
        "    --batch_size 4 \\\n",
        "    --gradient_accumulation_steps 2\n"
      ],
      "metadata": {
        "id": "jHwOpY9R_84S"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}